{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build an MNIST Classifier Pipeline using Kubeflow and SageMaker\n",
    "\n",
    "The `mnist-classification-pipeline.py` sample runs a pipeline that trains a classification model using k-means on the MNIST dataset with SageMaker."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook covers all required steps. For other details, such as how to obtain the source data, see the [documentation](https://github.com/kubeflow/pipelines/tree/master/samples/contrib/aws-samples/mnist-kmeans-sagemaker).\n",
    "\n",
    "\n",
    "This sample is based on the SageMaker example [Train a Model with a Built-in Algorithm and Deploy it](https://docs.aws.amazon.com/sagemaker/latest/dg/ex1.html).\n",
    "\n",
    "The sample trains and deploys a model based on the [MNIST dataset](http://www.deeplearning.net/tutorial/gettingstarted.html).\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install AWS Python SDK (`boto3`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install boto3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Kubeflow Pipelines SDK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install https://storage.googleapis.com/ml-pipeline/release/0.1.29/kfp.tar.gz --upgrade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restart the kernel to pick up pip installed libraries\n",
    "from IPython.core.display import HTML\n",
    "HTML(\"<script>Jupyter.notebook.kernel.restart()</script>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "\n",
    "AWS_REGION_AS_SLIST=!curl -s http://169.254.169.254/latest/meta-data/placement/availability-zone | sed 's/\\(.*\\)[a-z]/\\1/'\n",
    "AWS_REGION = AWS_REGION_AS_SLIST.s\n",
    "print('Region: {}'.format(AWS_REGION))\n",
    "\n",
    "AWS_ACCOUNT_ID=boto3.client('sts').get_caller_identity().get('Account')\n",
    "print('Account ID: {}'.format(AWS_ACCOUNT_ID))\n",
    "\n",
    "S3_BUCKET='sagemaker-{}-{}'.format(AWS_REGION, AWS_ACCOUNT_ID)\n",
    "print('S3 Bucket: {}'.format(S3_BUCKET))"
   ]
  },
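  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `curl | sed` command above derives the region by dropping the zone letter from an availability zone such as `us-west-2a`. A minimal pure-Python sketch of that step, assuming the zone has the usual `<region><letter>` form; `region_from_az` is a hypothetical helper name, not part of `boto3` or this sample:\n",
    "\n",
    "```python\n",
    "def region_from_az(availability_zone):\n",
    "    # Hypothetical helper: 'us-west-2a' -> 'us-west-2'.\n",
    "    az = availability_zone.strip()\n",
    "    # Drop a single trailing lowercase letter (the zone suffix), if present.\n",
    "    if az and az[-1].isalpha() and az[-1].islower():\n",
    "        return az[:-1]\n",
    "    # No zone suffix found: return the input unchanged.\n",
    "    return az\n",
    "\n",
    "\n",
    "print(region_from_az('us-west-2a'))\n",
    "```\n",
    "\n",
    "Unlike the `sed` expression, this sketch leaves a string without a trailing letter untouched, so passing a region name by mistake does not corrupt it."
   ]
  },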
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Copy `data` and `valid_data.csv` into your S3 bucket."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp s3://kubeflow-pipeline-data/mnist_kmeans_example/data s3://$S3_BUCKET/mnist_kmeans_example/data\n",
    "!aws s3 cp s3://kubeflow-pipeline-data/mnist_kmeans_example/input/valid_data.csv s3://$S3_BUCKET/mnist_kmeans_example/input/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Run the following command to load Kubeflow Pipelines SDK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import kfp\n",
    "from kfp import components\n",
    "from kfp import dsl\n",
    "from kfp.aws import use_aws_secret"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Load reusable SageMaker components."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_train_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/0ad6c28d32e2e790e6a129b7eb1de8ec59c1d45f/components/aws/sagemaker/train/component.yaml')\n",
    "sagemaker_model_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/0ad6c28d32e2e790e6a129b7eb1de8ec59c1d45f/components/aws/sagemaker/model/component.yaml')\n",
    "sagemaker_deploy_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/0ad6c28d32e2e790e6a129b7eb1de8ec59c1d45f/components/aws/sagemaker/deploy/component.yaml')\n",
    "sagemaker_batch_transform_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/0ad6c28d32e2e790e6a129b7eb1de8ec59c1d45f/components/aws/sagemaker/batch_transform/component.yaml')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Create pipeline. \n",
    "\n",
    "The pipeline first creates a training job. When the training job finishes, it persists the trained model to S3.\n",
    "\n",
    "Next, a job creates a `Model` manifest in SageMaker.\n",
    "\n",
    "A batch transform job can then use this model to run predictions on other datasets, and the prediction service can create an endpoint from it.\n",
    "\n",
    "\n",
    "> Note: Set **role_arn** to your own role so the job can run successfully.\n",
    "\n",
    "> Note: If you use a different region, replace `us-west-2` with your region.\n",
    "\n",
    "> Note: ECR registries for the k-means algorithm, by region:\n",
    "\n",
    "|Region|ECR Registry|\n",
    "|------|----------|\n",
    "|us-west-1|632365934929.dkr.ecr.us-west-1.amazonaws.com|\n",
    "|us-west-2|174872318107.dkr.ecr.us-west-2.amazonaws.com|\n",
    "|us-east-1|382416733822.dkr.ecr.us-east-1.amazonaws.com|\n",
    "|us-east-2|404615174143.dkr.ecr.us-east-2.amazonaws.com|\n",
    "|us-gov-west-1|226302683700.dkr.ecr.us-gov-west-1.amazonaws.com|\n",
    "|ap-east-1|286214385809.dkr.ecr.ap-east-1.amazonaws.com|\n",
    "|ap-northeast-1|351501993468.dkr.ecr.ap-northeast-1.amazonaws.com|\n",
    "|ap-northeast-2|835164637446.dkr.ecr.ap-northeast-2.amazonaws.com|\n",
    "|ap-south-1|991648021394.dkr.ecr.ap-south-1.amazonaws.com|\n",
    "|ap-southeast-1|475088953585.dkr.ecr.ap-southeast-1.amazonaws.com|\n",
    "|ap-southeast-2|712309505854.dkr.ecr.ap-southeast-2.amazonaws.com|\n",
    "|ca-central-1|469771592824.dkr.ecr.ca-central-1.amazonaws.com|\n",
    "|eu-central-1|664544806723.dkr.ecr.eu-central-1.amazonaws.com|\n",
    "|eu-north-1|669576153137.dkr.ecr.eu-north-1.amazonaws.com|\n",
    "|eu-west-1|438346466558.dkr.ecr.eu-west-1.amazonaws.com|\n",
    "|eu-west-2|644912444149.dkr.ecr.eu-west-2.amazonaws.com|\n",
    "|eu-west-3|749696950732.dkr.ecr.eu-west-3.amazonaws.com|\n",
    "|me-south-1|249704162688.dkr.ecr.me-south-1.amazonaws.com|\n",
    "|sa-east-1|855470959533.dkr.ecr.sa-east-1.amazonaws.com|"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAGEMAKER_ROLE_ARN='arn:aws:iam::{}:role/TeamRole'.format(AWS_ACCOUNT_ID)\n",
    "\n",
    "# Configure your s3 bucket.\n",
    "S3_PIPELINE_PATH='s3://{}/mnist_kmeans_example'.format(S3_BUCKET)\n",
    "\n",
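    "# TODO:  Implement the other region checks\n",
    "if AWS_REGION == 'us-west-2':\n",
    "    AWS_ECR_REGISTRY='174872318107.dkr.ecr.us-west-2.amazonaws.com'\n",
    "elif AWS_REGION == 'us-east-1':\n",
    "    AWS_ECR_REGISTRY='382416733822.dkr.ecr.us-east-1.amazonaws.com'\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError when the pipeline is defined.\n",
    "    raise ValueError('No ECR registry configured for region: {}'.format(AWS_REGION))"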
   ]
  },
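  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The region checks above could also be written as a table lookup. The following is a sketch, not part of the original sample: the account IDs are copied from the table in step 3, and `KMEANS_ACCOUNT_IDS` and `kmeans_registry` are hypothetical names introduced here for illustration.\n",
    "\n",
    "```python\n",
    "# Account IDs copied from the registry table in step 3.\n",
    "KMEANS_ACCOUNT_IDS = {\n",
    "    'us-west-1': '632365934929',\n",
    "    'us-west-2': '174872318107',\n",
    "    'us-east-1': '382416733822',\n",
    "    'us-east-2': '404615174143',\n",
    "    'us-gov-west-1': '226302683700',\n",
    "    'ap-east-1': '286214385809',\n",
    "    'ap-northeast-1': '351501993468',\n",
    "    'ap-northeast-2': '835164637446',\n",
    "    'ap-south-1': '991648021394',\n",
    "    'ap-southeast-1': '475088953585',\n",
    "    'ap-southeast-2': '712309505854',\n",
    "    'ca-central-1': '469771592824',\n",
    "    'eu-central-1': '664544806723',\n",
    "    'eu-north-1': '669576153137',\n",
    "    'eu-west-1': '438346466558',\n",
    "    'eu-west-2': '644912444149',\n",
    "    'eu-west-3': '749696950732',\n",
    "    'me-south-1': '249704162688',\n",
    "    'sa-east-1': '855470959533',\n",
    "}\n",
    "\n",
    "\n",
    "def kmeans_registry(region):\n",
    "    # Hypothetical helper: build the registry host, or fail with a readable error.\n",
    "    if region not in KMEANS_ACCOUNT_IDS:\n",
    "        raise ValueError('No k-means registry listed for region: {}'.format(region))\n",
    "    return '{}.dkr.ecr.{}.amazonaws.com'.format(KMEANS_ACCOUNT_IDS[region], region)\n",
    "\n",
    "\n",
    "print(kmeans_registry('us-west-2'))\n",
    "```\n",
    "\n",
    "With this in place, `AWS_ECR_REGISTRY = kmeans_registry(AWS_REGION)` would cover every region in the table, assuming the table is still current."
   ]
  },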
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dsl.pipeline(\n",
    "    name='MNIST Classification pipeline',\n",
    "    description='MNIST Classification using KMEANS in SageMaker'\n",
    ")\n",
    "def mnist_classification(region=AWS_REGION,\n",
    "    image='{}/kmeans:1'.format(AWS_ECR_REGISTRY),\n",
    "    dataset_path=S3_PIPELINE_PATH + '/data',\n",
    "    instance_type='ml.c4.8xlarge',\n",
    "    instance_count='2',\n",
    "    volume_size='50',\n",
    "    model_output_path=S3_PIPELINE_PATH + '/model',\n",
    "    batch_transform_input=S3_PIPELINE_PATH + '/input',\n",
    "    batch_transform_ouput=S3_PIPELINE_PATH + '/output',\n",
    "    role_arn=SAGEMAKER_ROLE_ARN\n",
    "    ):\n",
    "\n",
    "    training = sagemaker_train_op(\n",
    "        region=region,\n",
    "        image=image,\n",
    "        instance_type=instance_type,\n",
    "        instance_count=instance_count,\n",
    "        volume_size=volume_size,\n",
    "        dataset_path=dataset_path,\n",
    "        model_artifact_path=model_output_path,\n",
    "        role=role_arn,\n",
    "    ).apply(use_aws_secret('aws-secret', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'))\n",
    "\n",
    "    create_model = sagemaker_model_op(\n",
    "        region=region,\n",
    "        image=image,\n",
    "        model_artifact_url=training.outputs['model_artifact_url'],\n",
    "        model_name=training.outputs['job_name'],\n",
    "        role=role_arn\n",
    "    ).apply(use_aws_secret('aws-secret', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'))\n",
    "\n",
    "    prediction = sagemaker_deploy_op(\n",
    "        region=region,\n",
    "        model_name=create_model.output\n",
    "    ).apply(use_aws_secret('aws-secret', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'))\n",
    "\n",
    "    batch_transform = sagemaker_batch_transform_op(\n",
    "        region=region,\n",
    "        model_name=create_model.output,\n",
    "        input_location=batch_transform_input,\n",
    "        output_location=batch_transform_ouput\n",
    "    ).apply(use_aws_secret('aws-secret', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Compile your pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kfp.compiler.Compiler().compile(mnist_classification, 'mnist-classification-pipeline.zip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls -al ./mnist-classification-pipeline.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!unzip -o ./mnist-classification-pipeline.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat pipeline.yaml"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5. Deploy your pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "client = kfp.Client()\n",
    "aws_experiment = client.create_experiment(name='aws')\n",
    "my_run = client.run_pipeline(aws_experiment.id, 'mnist-classification-pipeline', \n",
    "  'mnist-classification-pipeline.zip')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "_Note:  The above training job may take 5-10 minutes.  Please be patient._\n",
    "\n",
    "In the meantime, open the SageMaker Console to monitor the progress of your training job.\n",
    "\n",
    "![SageMaker Training Job Console](img/sagemaker-training-job-console.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the Name of the Deployed Prediction Endpoint\n",
    "First, we need to get the endpoint name of our newly-deployed SageMaker Prediction Endpoint."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Open the AWS console, go to the SageMaker service, and find the endpoint name as shown in the following picture.\n",
    "\n",
    "![download-pipeline](images/sm-endpoint.jpg)"
   ]
  },
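  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an alternative to the console, the endpoint name can probably be looked up with `boto3`. This is a sketch under one loud assumption: the most recently created `InService` endpoint in your account and region is the one this pipeline deployed. `latest_endpoint_name` is a hypothetical helper; `list_endpoints` is the `boto3` SageMaker client call it relies on.\n",
    "\n",
    "```python\n",
    "def latest_endpoint_name(sagemaker_client):\n",
    "    # Hypothetical helper: name of the newest InService endpoint, or None.\n",
    "    response = sagemaker_client.list_endpoints(\n",
    "        SortBy='CreationTime',\n",
    "        SortOrder='Descending',\n",
    "        StatusEquals='InService',\n",
    "        MaxResults=1,\n",
    "    )\n",
    "    endpoints = response.get('Endpoints', [])\n",
    "    # Assumption: the newest endpoint belongs to this pipeline run.\n",
    "    return endpoints[0]['EndpointName'] if endpoints else None\n",
    "```\n",
    "\n",
    "For example, `latest_endpoint_name(boto3.client('sagemaker', region_name=AWS_REGION))` returns `None` until the deploy step has finished. Check the result against the console if other endpoints exist in the account."
   ]
  },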
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make a Prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# _YOU MUST COPY/PASTE THE `ENDPOINT_NAME` BEFORE CONTINUING_\n",
    "Make sure to preserve the single quotes as shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip\n",
    "import io\n",
    "import json\n",
    "import pickle\n",
    "import urllib.request\n",
    "\n",
    "import boto3\n",
    "import numpy\n",
    "\n",
    "#################################\n",
    "#################################\n",
    "# Replace ENDPOINT_NAME with the endpoint name in the SageMaker console.\n",
    "# Surround with single quotes.\n",
    "ENDPOINT_NAME= # 'Endpoint-<your-endpoint-name>'\n",
    "#################################\n",
    "#################################\n",
    "\n",
    "# Load the dataset\n",
    "urllib.request.urlretrieve(\"http://deeplearning.net/data/mnist/mnist.pkl.gz\", \"mnist.pkl.gz\")\n",
    "with gzip.open('mnist.pkl.gz', 'rb') as f:\n",
    "    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')\n",
    "\n",
    "# Simple function to create a csv from our numpy array\n",
    "def np2csv(arr):\n",
    "    csv = io.BytesIO()\n",
    "    numpy.savetxt(csv, arr, delimiter=',', fmt='%g')\n",
    "    return csv.getvalue().decode().rstrip()\n",
    "\n",
    "runtime = boto3.Session(region_name=AWS_REGION).client('sagemaker-runtime')\n",
    "\n",
    "payload = np2csv(train_set[0][30:31])\n",
    "\n",
    "response = runtime.invoke_endpoint(EndpointName=ENDPOINT_NAME,\n",
    "                                   ContentType='text/csv',\n",
    "                                   Body=payload)\n",
    "result = json.loads(response['Body'].read().decode())\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean up\n",
    "\n",
    "Go to the SageMaker console and delete the `endpoint` and `model`."
   ]
  },
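  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same clean-up can be sketched with `boto3` instead of the console. `delete_endpoint_and_model` is a hypothetical helper, and it goes one step beyond the instructions above by also deleting the endpoint configuration, on the assumption that you do not want to keep it. You supply the endpoint and model names yourself.\n",
    "\n",
    "```python\n",
    "def delete_endpoint_and_model(sagemaker_client, endpoint_name, model_name):\n",
    "    # Look up the endpoint configuration while the endpoint still exists.\n",
    "    description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "    config_name = description['EndpointConfigName']\n",
    "    # Delete the endpoint first, then the resources it referenced.\n",
    "    sagemaker_client.delete_endpoint(EndpointName=endpoint_name)\n",
    "    sagemaker_client.delete_endpoint_config(EndpointConfigName=config_name)\n",
    "    sagemaker_client.delete_model(ModelName=model_name)\n",
    "    return config_name\n",
    "```\n",
    "\n",
    "A call would look like `delete_endpoint_and_model(boto3.client('sagemaker', region_name=AWS_REGION), ENDPOINT_NAME, '<your-model-name>')`. These deletions are not reversible, so double-check both names before running it."
   ]
  },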
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
